#!/usr/bin/env python
# coding: utf-8

# # 数据处理
# 数据是深度学习的基础，良好的数据输入可以对整个深度神经网络训练起到非常积极的作用。在训练前对已加载的数据集进行数据处理，可以解决诸如数据量过大、样本分布不均等问题，从而获得对训练结果更有利的数据输入。
# 
# luojianet的各个数据集类都为用户提供了多种数据处理操作，用户可以通过构建数据处理的流水线（pipeline）来定义需要使用的数据处理操作，在训练过程中，数据即可像水一样源源不断地经过数据处理pipeline流向训练系统。
# 
# luojianet目前支持如数据清洗`shuffle`、数据分批`batch`、数据重复`repeat`、数据拼接`concat`等常用数据处理操作。
# 
# > 更多数据处理操作参见[API文档](http://58.48.42.237/luojiaNet/luojiaNetapi/)。
# 
# ## 数据处理操作
# 
# ### shuffle
# 
# shuffle操作会随机打乱数据顺序，对数据集进行混洗。
# 
# 设定的`buffer_size`越大，数据混洗程度越大，同时所消耗的时间、计算资源也更大。
# 
# ![shuffle](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.7/tutorials/source_zh_cn/advanced/dataset/images/op_shuffle.png)
# 
# 下面的样例先构建了一个随机数据集，然后对其进行混洗操作，最后展示了数据混洗前后的结果。

# In[1中低阶API实现深度学习]:


import numpy as np
import luojianet.dataset as ds

ds.config.set_seed(0)

def generator_func():
    """定义生成数据集函数"""
    for i in range(5):
        yield (np.array([i, i+1, i+2]),)

# 生成数据集
dataset = ds.GeneratorDataset(generator_func, ["data"])
for data in dataset.create_dict_iterator():
    print(data)

print("------ after processing ------")

# 执行数据清洗操作
dataset = dataset.shuffle(buffer_size=2)
for data in dataset.create_dict_iterator():
    print(data)


# 从上面的打印结果可以看出，经过`shuffle`操作之后，数据顺序被打乱了。
# 
# ### batch
# 
# batch操作将数据集分批，分别输入到训练系统中进行训练，可以减少训练轮次，达到加速训练过程的目的。
# 
# ![batch](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.7/tutorials/source_zh_cn/advanced/dataset/images/op_batch.png)
# 
# 下面的样例先构建了一个数据集，然后分别展示了丢弃多余数据与否的数据集分批结果，其中批大小为2。

# In[4自然语言]:


import numpy as np
import luojianet.dataset as ds

def generator_func():
    """定义生成数据集函数"""
    for i in range(5):
        yield (np.array([i, i+1, i+2]),)

dataset = ds.GeneratorDataset(generator_func, ["data"])
for data in dataset.create_dict_iterator():
    print(data)

# 采用不丢弃多余数据的方式对数据集进行分批
dataset = ds.GeneratorDataset(generator_func, ["data"])
dataset = dataset.batch(batch_size=2, drop_remainder=False)
print("------not drop remainder ------")
for data in dataset.create_dict_iterator():
    print(data)

# 采用丢弃多余数据的方式对数据集进行分批
dataset = ds.GeneratorDataset(generator_func, ["data"])
dataset = dataset.batch(batch_size=2, drop_remainder=True)
print("------ drop remainder ------")
for data in dataset.create_dict_iterator():
    print(data)


# 从上面的打印结果可以看出，数据集大小为5，每2个分一组，不丢弃多余数据时分为3组，丢弃多余数据时分为2组，最后一条数据被丢弃。
# 
# ### repeat
# 
# repeat操作对数据集进行重复，达到扩充数据量的目的。`repeat`和`batch`操作的先后顺序会影响训练batch的数量，建议将`repeat`置于`batch`之后。
# 
# ![repeat](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.7/tutorials/source_zh_cn/advanced/dataset/images/op_repeat.png)
# 
# 下面的样例先构建了一个随机数据集，然后将其重复2次，最后展示了重复后的数据结果。

# In[40]:


import numpy as np
import luojianet.dataset as ds

def generator_func():
    """定义生成数据集函数"""
    for i in range(5):
        yield (np.array([i, i+1, i+2]),)

# 生成数据集
dataset = ds.GeneratorDataset(generator_func, ["data"])
for data in dataset.create_dict_iterator():
    print(data)

print("------ after processing ------")

# 对数据进行数据重复操作
dataset = dataset.repeat(count=2)
for data in dataset.create_dict_iterator():
    print(data)


# 从上面的打印结果可以看出，数据集被拷贝了之后扩充到原数据集后面。
# 
# ### zip
# 
# zip操作实现两个数据集的列拼接，将其合并为一个数据集。使用时需要注意以下两点：
# 
# 1中低阶API实现深度学习. 如果两个数据集的列名相同，则不会合并，请注意列的命名。
# 2高级数据集管理. 如果两个数据集的行数不同，合并后的行数将和较小行数保持一致。
# 
# ![zip](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.7/tutorials/source_zh_cn/advanced/dataset/images/op_zip.png)
# 
# 下面的样例先构建了两个不同样本数的随机数据集，然后将其进行列拼接，最后展示了拼接后的数据结果。

# In[5]:


import numpy as np
import luojianet.dataset as ds

def generator_func():
    """定义生成数据集函数1"""
    for i in range(7):
        yield (np.array([i, i+1, i+2]),)

def generator_func2():
    """定义生成数据集函数2"""
    for _ in range(4):
        yield (np.array([1, 2]),)

print("------ data1 ------")
dataset1 = ds.GeneratorDataset(generator_func, ["data1"])
for data in dataset1.create_dict_iterator():
    print(data)

print("------ data2 ------")
dataset2 = ds.GeneratorDataset(generator_func2, ["data2"])
for data in dataset2.create_dict_iterator():
    print(data)

print("------ data3 ------")

# 对数据集1和数据集2做zip操作，生成数据集3
dataset3 = ds.zip((dataset1, dataset2))
for data in dataset3.create_dict_iterator():
    print(data)


# 从上面的打印结果可以看出，数据集3由数据集1和数据集2列拼接得到，其列数为后两者之和，其行数与后两者中最小行数（数据集2行数）保持一致，数据集1中后面多余的行数被丢弃。
# 
# ### concat
# 
# concat实现两个数据集的行拼接，并将其合并为一个数据集。使用时需要注意：输入数据集中的列名、列数据类型和列数据的排列应相同。
# 
# ![concat](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.7/tutorials/source_zh_cn/advanced/dataset/images/op_concat.png)
# 
# 下面的样例先构建了两个随机数据集，然后将其做行拼接，最后展示了拼接后的数据结果。值得一提的是，使用`+`运算符也能达到同样的效果。

# In[42]:


import numpy as np
import luojianet.dataset as ds

def generator_func():
    """定义生成数据集函数1"""
    for _ in range(2):
        yield (np.array([0, 0, 0]),)

def generator_func2():
    """定义生成数据集函数2"""
    for _ in range(2):
        yield (np.array([1, 2, 3]),)

# 生成数据集1
dataset1 = ds.GeneratorDataset(generator_func, ["data"])
print("data1:")
for data in dataset1.create_dict_iterator():
    print(data)

# 生成数据集2
dataset2 = ds.GeneratorDataset(generator_func2, ["data"])
print("data2:")
for data in dataset2.create_dict_iterator():
    print(data)

# 在数据集1上concat数据集2，生成数据集3
dataset3 = dataset1.concat(dataset2)
print("data3:")
for data in dataset3.create_dict_iterator():
    print(data)


# 从上面的打印结果可以看出，数据集3由数据集1和数据集2行拼接得到，其列数与后两者保持一致，其行数为后两者之和。
# 
# ### map
# 
# map操作将指定的函数作用于数据集的指定列数据，实现数据映射操作。
# 
# 用户可以自定义映射函数，也可以直接使用`c_transforms`或`py_transforms`中的函数针对图像、文本数据进行数据增强。
# 
# ![map](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.7/tutorials/source_zh_cn/advanced/dataset/images/op_map.png)
# 
# 下面的样例先构建了一个随机数据集，然后定义了数据翻倍的映射函数并将其作用于数据集，最后对比展示了映射前后的数据结果。

# In[38]:


import numpy as np
import luojianet.dataset as ds

def generator_func():
    """定义生成数据集函数"""
    for i in range(5):
        yield (np.array([i, i+1, i+2]),)

def pyfunc(x):
    """定义对数据的操作"""
    return x*2

# 生成数据集
dataset = ds.GeneratorDataset(generator_func, ["data"])

# 显示上述生成的数据集
for data in dataset.create_dict_iterator():
    print(data)

print("------ after processing ------")

# 对数据集做map操作，操作函数为pyfunc
dataset = dataset.map(operations=pyfunc, input_columns=["data"])

# 显示map操作后的数据集
for data in dataset.create_dict_iterator():
    print(data)


# 从上面的打印结果可以看出，经过map操作，将函数`pyfunc`作用到数据集后，数据集中每个数据都被乘2。
